#!/opt/conda/bin/python
from __future__ import print_function

import sys

import os
import json
import pickle
import sys
import traceback
import copy
import shutil
from glob import glob

from fairseq import distributed_utils, options

from train_driver import main


# Locations inside the container where SageMaker mounts its inputs and
# collects the outputs of the training job.

prefix = '/opt/ml/'

# Everything below is derived from ``prefix``.
input_path = os.path.join(prefix, 'input/data')
output_path = os.path.join(prefix, 'output')
model_path = os.path.join(prefix, 'model')
param_path = os.path.join(prefix, 'input/config/hyperparameters.json')

# The algorithm reads from one input channel, named 'training'. The job runs
# in File mode, so the channel's files are copied into this directory.
channel_name = 'training'
training_path = os.path.join(input_path, channel_name)


# Hyperparameters listed here are treated as boolean switches: they are
# emitted as a bare ``--flag`` when enabled and left out entirely otherwise.
_STORE_TRUE_ACTIONS = [
    'no-progress-bar',
    'fp16',
    'skip-invalid-size-inputs-valid-test',
    'fix-batches-to-gpus',
    'sentence-avg',
    'reset-optimizer',
    'reset-lr-scheduler',
    'no-save',
    'no-epoch-checkpoints',
    'cpu',
    'quiet',
    'output-word-probs',
    'output-word-stats',
    'no-early-stop',
    'unnormalized',
    'no-beamable-mm',
    'score-reference',
    'sampling',
    'print-alignment',
]


def _is_enabled(value):
    """Return True when a hyperparameter value switches a boolean flag on.

    Accepts a real boolean as well as the string 'true' in any letter case
    (surrounding whitespace ignored); every other value counts as disabled.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() == 'true'


def convert_args_dict_to_list(args_dict):
    """Convert a hyperparameter mapping into a command-line argument list.

    Args:
        args_dict: mapping of option name (without the leading ``--``) to
            its value.

    Returns:
        A list suitable for argparse. Keys found in ``_STORE_TRUE_ACTIONS``
        contribute a single ``--key`` element when their value is enabled
        and nothing otherwise; every other key contributes ``--key``
        followed by its value.
    """
    args = []
    for key, value in args_dict.items():
        option = '--{}'.format(key)
        if key in _STORE_TRUE_ACTIONS:
            # Boolean switch: only its presence matters, it takes no value.
            if _is_enabled(value):
                args.append(option)
            continue
        # argparse only handles strings, so numbers decoded from JSON are
        # stringified; values that are already strings pass through as-is.
        if isinstance(value, (int, float)):
            value = str(value)
        args.extend([option, value])
    return args


# Build the FairSeq options namespace from command-line style arguments.
def _parse_fairseq_args(training_args):
    """Parse ``training_args`` with the FairSeq training parser.

    The FairSeq parser is fed through ``sys.argv``, so ``sys.argv`` is
    temporarily rewritten to hold the training data directory, the given
    arguments and ``--save-dir`` pointing at the model directory. The
    original ``sys.argv`` is restored on every exit path.

    Args:
        training_args: list of command-line style arguments.

    Returns:
        The parsed options object returned by
        ``options.parse_args_and_arch``.

    Raises:
        RuntimeError: if parsing terminates the interpreter (argparse
            reports invalid arguments by raising ``SystemExit``).
    """
    argv_copy = copy.deepcopy(sys.argv)
    # some arguments are pre-defined for SageMaker
    sys.argv[1:] = [training_path] + training_args + ["--save-dir", model_path]
    try:
        # get options from FairSeq
        parser = options.get_training_parser()
        # get inference/serve params
        options.add_generation_args(parser)
        options.add_interactive_args(parser)
        return options.parse_args_and_arch(parser)
    except SystemExit as exit_err:
        # SystemExit is not an Exception subclass; convert it so the caller's
        # error handling (failure file, exit code) still applies.
        raise RuntimeError(
            'FairSeq argument parsing exited with status {}; '
            'check the hyperparameters.'.format(exit_err.code))
    finally:
        # remove the modifications we did in the command-line arguments
        sys.argv = argv_copy


# The function to execute the training.
def train():
    """Run a training job from the hyperparameters file.

    Reads the hyperparameters JSON, converts it to FairSeq options, copies
    the dictionary and BPE code files from the training channel into the
    model directory, and then dispatches to one of three entry points:
    ``distributed_train.main`` when the resource configuration lists more
    than one host, ``multiprocessing_train.main`` when
    ``args.distributed_world_size`` is greater than one, or ``main``
    otherwise.

    Any exception is reported on stderr and in the ``failure`` file under
    the output directory, after which the process exits with status 255.
    """
    print('Starting the training.')
    try:
        # Read in any hyperparameters that the user passed with the training job
        with open(param_path, 'r') as tc:
            training_params = json.load(tc)

        print(training_params)

        training_args = convert_args_dict_to_list(training_params)
        print(training_args)

        # get args for FairSeq by converting the hyperparameters as if they
        # were command-line arguments
        args = _parse_fairseq_args(training_args)

        print(args)

        # copy dictionaries (and bpecodes if available) to the model dir as
        # they are needed by the model during serving
        serving_files = glob(os.path.join(training_path, "dict*.txt"))
        serving_files += glob(os.path.join(training_path, "bpecodes"))
        for src in serving_files:
            shutil.copy(src, model_path)

        # main training function
        resource_config_path = os.path.join(
            prefix, 'input/config/resourceconfig.json')
        with open(resource_config_path, 'r') as f:
            resource_config = json.load(f)
        hosts = resource_config['hosts']
        if len(hosts) > 1:
            from distributed_train import main as distributed_main

            distributed_main(args)
        elif args.distributed_world_size > 1:
            from multiprocessing_train import main as multiprocessing_main

            multiprocessing_main(args)
        else:
            main(args)

        print('Training complete.')
    except Exception as e:
        trc = traceback.format_exc()
        message = 'Exception during training: ' + str(e) + '\n' + trc
        # Printing this causes the exception to be in the training job logs.
        # It is done first so the error is logged even if the failure file
        # below cannot be written.
        print(message, file=sys.stderr)
        # Write out an error file. This will be returned as the failureReason in the
        # DescribeTrainingJob result.
        try:
            with open(os.path.join(output_path, 'failure'), 'w') as s:
                s.write(message)
        except (IOError, OSError) as write_err:
            print('Could not write the failure file: ' + str(write_err),
                  file=sys.stderr)
        # A non-zero exit code causes the training job to be marked as Failed.
        sys.exit(255)


if __name__ == '__main__':
    # Run the training job; train() exits with status 255 itself on failure.
    train()

    # Training finished without error: a zero exit code causes the job to be
    # marked as Succeeded.
    raise SystemExit(0)